package main

import (
	"bytes"
	"flag"
	"fmt"
	"github.com/zeromicro/go-zero/tools/goctl/api/spec"
	"go-zero-dandan/common/fmtd"
	"go-zero-dandan/pkg/arrd"
	"go-zero-dandan/pkg/parsed"
	"os"
	"os/exec"
	"runtime"
	"strings"
	"text/template"
	"unicode"
)

var rpcName = flag.String("rpc", "", "rpc name")

const rootPath = "/Users/yelin/go_dev/project/src/go-zero-dandan"

// main reads the -rpc flag, checks that <rootPath>/app/<rpc>/rpc/<rpc>.rpc
// exists, parses it with parsed.ParseGoZeroApiByFile (goctl's api parser) and
// hands the result to genProto. Every failure is reported via fmtd.Fatal.
func main() {
	flag.Parse()
	name := *rpcName
	if name == "" {
		fmtd.Fatal("rpc name is empty")
	}

	specPath := fmt.Sprintf("%s/app/%s/rpc/%s.rpc", rootPath, name, name)
	// The parser does not fail loudly when the file is missing, so verify it
	// exists up front.
	if _, statErr := os.Stat(specPath); statErr != nil {
		fmtd.Fatal(statErr)
	}

	api, parseErr := parsed.ParseGoZeroApiByFile(specPath)
	if parseErr != nil {
		fmtd.Fatal(parseErr)
	}
	genProto(name, api)
}
// runCmd prints cmd, then runs it through the platform shell ("cmd /C" on
// Windows, "/bin/sh -c" elsewhere) with the child's stdout and stderr wired
// to this process's. It returns the error from exec.Cmd.Run, which is
// non-nil when the command cannot start or exits with a non-zero status.
func runCmd(cmd string) error {
	fmt.Printf("run cmd : %s\n", cmd)

	shell, shellFlag := "/bin/sh", "-c"
	if runtime.GOOS == "windows" {
		shell, shellFlag = "cmd", "/C"
	}

	proc := exec.Command(shell, shellFlag, cmd)
	proc.Stdout, proc.Stderr = os.Stdout, os.Stderr
	return proc.Run()
}

// temp is the text/template source of the generated .proto file.
//
// Template keys (filled by genProto):
//   - rpcName:  used for both the go_package option and the proto package name;
//   - messages: pre-rendered "message ... { ... }" definitions;
//   - services: the pre-rendered "service ... { ... }" block.
//
// It always declares an empty EmptyReq message, which getServices uses as the
// rpc input for routes that have no request type.
//
// NOTE(review): every byte of this raw string (including the whitespace-only
// lines) is written verbatim into the generated file.
var temp = `// Code generated by goctl. DO NOT EDIT.
syntax = "proto3";
	
option go_package = "./{{.rpcName}}Rpc";

package {{.rpcName}};

message EmptyReq{}
{{.messages}}
	
{{.services}}
`

// genProto renders the proto3 definition for rpcName from the parsed api spec
// and writes it to <rootPath>/app/<rpcName>/rpc/<rpcName>.proto (mode 0644).
//
// Template, rendering and write errors are reported via fmtd.Fatal. On success
// it prints exactly "gen proto success"; a calling shell script matches that
// line to decide whether generation worked, so the text must not change.
func genProto(rpcName string, api *spec.ApiSpec) {
	tmpl, err := template.New("proto").Parse(temp)
	if err != nil {
		fmtd.Fatal(err)
	}

	// The service block also yields the request type names, which decide
	// which messages get optional fields.
	serviceBlock, requestTypes := getServices(rpcName, api)
	data := map[string]any{
		"rpcName":  rpcName,
		"messages": getMessages(api, requestTypes),
		"services": serviceBlock,
	}

	var rendered bytes.Buffer
	if err = tmpl.Execute(&rendered, data); err != nil {
		fmtd.Fatal("Error executing template:", err)
	}

	protoPath := fmt.Sprintf("%s/app/%s/rpc/%s.proto", rootPath, rpcName, rpcName)
	if err = os.WriteFile(protoPath, rendered.Bytes(), 0644); err != nil {
		fmtd.Fatal(err)
	}
	fmt.Println("gen proto success")
}
// getServices renders the proto "service <rpcName> { ... }" block with one
// rpc line per route of api.
//
// A route's request type becomes the rpc input and its response type the rpc
// output. When either is missing, EmptyReq (the empty message declared in the
// proto template) is used instead, because proto3 requires both. Previously a
// route without a request type panicked (nil RequestType), and one without a
// response type emitted an unterminated rpc line.
//
// reqList collects the names of the routes' real request types; getMessages
// uses it to mark the fields of those messages optional.
func getServices(rpcName string, api *spec.ApiSpec) (serviceStr string, reqList []string) {
	var b strings.Builder
	b.WriteString("service " + rpcName + " {\n")
	reqList = make([]string, 0)

	for _, route := range api.Service.Routes() {
		// RequestTypeName/ResponseTypeName return "" for a nil type instead
		// of panicking like RequestType.Name() does.
		reqName := route.RequestTypeName()
		if reqName == "" {
			reqName = "EmptyReq"
		} else {
			reqList = append(reqList, reqName)
		}

		respName := route.ResponseTypeName()
		if respName == "" {
			respName = "EmptyReq"
		}

		fmt.Fprintf(&b, "\trpc %s(%s) returns (%s);\n", route.Handler, reqName, respName)
	}

	b.WriteString("}\n")
	return b.String(), reqList
}
// getMessages renders one proto "message" definition per type in api.Types.
//
// Field names are the Go member names with a lower-cased first rune, and field
// numbers follow declaration order starting at 1. Types and labels come from
// transReqType. Fields of structs whose name appears in reqList (the request
// types collected by getServices) are marked optional.
//
// Only struct definitions are supported; any other type is reported through
// fmtd.Fatal.
func getMessages(api *spec.ApiSpec, reqList []string) string {
	var b strings.Builder
	for _, tp := range api.Types {
		obj, ok := tp.(spec.DefineStruct)
		if !ok {
			fmtd.Fatal("unspport struct type: "+tp.Name(), "unspport struct type: "+tp.Name())
		}

		// Whether this struct is a request only depends on the struct, not on
		// the field, so compute it once.
		isReq := arrd.Contain(reqList, obj.RawName)

		b.WriteString("message " + obj.Name() + " {\n")
		for i, field := range obj.Members {
			fieldType, fieldAttr := transReqType(field.Type.Name(), isReq)
			// Omit the label separator when there is no label, so the output
			// has no stray leading space.
			decl := fieldType
			if fieldAttr != "" {
				decl = fieldAttr + " " + fieldType
			}
			fmt.Fprintf(&b, "\t%s %s = %d;\n", decl, toFirstLower(field.Name), i+1)
		}
		b.WriteString("}\n")
	}
	return b.String()
}
// toFirstUpper returns s with its first rune upper-cased via unicode.ToUpper.
// The empty string is returned unchanged. Non-empty input is round-tripped
// through []rune, so any invalid UTF-8 bytes come back as U+FFFD.
func toFirstUpper(s string) string {
	runes := []rune(s)
	if len(runes) > 0 {
		runes[0] = unicode.ToUpper(runes[0])
		return string(runes)
	}
	return s
}
// toFirstLower returns s with its first rune lower-cased via unicode.ToLower.
// getMessages uses it to turn Go member names into proto field names.
// The empty string is returned unchanged. Non-empty input is round-tripped
// through []rune, so any invalid UTF-8 bytes come back as U+FFFD.
func toFirstLower(s string) string {
	chars := []rune(s)
	if len(chars) == 0 {
		return s
	}
	head := unicode.ToLower(chars[0])
	chars[0] = head
	return string(chars)
}
// transReqType maps a goctl member type name (e.g. "int64", "*string",
// "[]*User", "map[string]int") to a proto3 field type and label.
//
// isReq reports whether the member belongs to a request message. The label
// (fieldAttr) is:
//   - "repeated" for slices other than []byte (pointer elements are unwrapped);
//   - "" for []byte (proto bytes) and for maps (converted by
//     convertGoMapToProto), which cannot carry a label;
//   - otherwise "optional" for request members and for pointer members, and ""
//     for everything else.
//
// Go scalar types that have no identically named proto3 scalar (int, float64,
// uint8, ...) are translated by goScalarToProto. Previously they were emitted
// verbatim, which protoc rejects. Other names pass through unchanged.
// An empty typeName is reported through fmtd.Fatal.
func transReqType(typeName string, isReq bool) (fieldType, fieldAttr string) {
	if typeName == "" {
		fmtd.Fatal("typeName is empty")
	}

	// A single leading "*" marks an explicitly optional member.
	isPt := strings.HasPrefix(typeName, "*")
	typeName = strings.TrimPrefix(typeName, "*")

	// Request members are always optional so callers may omit them.
	// Other members are optional only when declared as pointers.
	fieldAttr = ""
	if isReq || isPt {
		fieldAttr = "optional"
	}

	switch {
	case typeName == "[]byte":
		return "bytes", ""
	case len(typeName) > len("[]") && strings.HasPrefix(typeName, "[]"):
		elem := strings.TrimPrefix(typeName[len("[]"):], "*")
		return goScalarToProto(elem), "repeated"
	case len(typeName) > len("map[") && strings.HasPrefix(typeName, "map["):
		return convertGoMapToProto(typeName), ""
	default:
		return goScalarToProto(typeName), fieldAttr
	}
}

// goScalarToProto translates Go scalar type names that have no identically
// named proto3 scalar into their proto3 counterpart. int and uint map to the
// 32-bit proto types, matching convertGoMapToProto's existing int -> int32
// choice for map values. Any other name (valid proto scalars such as int64 or
// string, message names, ...) is returned unchanged.
func goScalarToProto(t string) string {
	switch t {
	case "int", "int8", "int16", "rune":
		return "int32"
	case "uint", "uint8", "uint16", "byte":
		return "uint32"
	case "float32":
		return "float"
	case "float64":
		return "double"
	}
	return t
}

// convertGoMapToProto converts a Go map type name such as "map[string]*User"
// into a proto3 map type such as "map<string, User>".
//
// The key is everything up to the first "]" after "map[". The rest is the
// value type. A single leading "*" on the value is dropped, and []byte values
// become bytes. Go scalars without an identically named proto3 scalar are
// translated in both key and value (int -> int32, uint -> uint32,
// float64 -> double, ...).
//
// The sentinel string "Invalid Go map type" is returned in three cases:
//   - malformed input (missing "]", empty key or value type);
//   - slice, array or nested-map values, which proto3 maps cannot hold;
//   - values that are empty once the "*" is dropped.
// Previously an empty value type such as "map[string]" panicked.
func convertGoMapToProto(goMapStr string) string {
	const invalid = "Invalid Go map type"

	toProto := func(t string) string {
		switch t {
		case "int", "int8", "int16", "rune":
			return "int32"
		case "uint", "uint8", "uint16", "byte":
			return "uint32"
		case "float32":
			return "float"
		case "float64":
			return "double"
		}
		return t
	}

	keyType, valueType, found := strings.Cut(strings.TrimPrefix(goMapStr, "map["), "]")
	if !found || keyType == "" || valueType == "" {
		return invalid
	}

	// Pointer values are plain message fields in proto.
	valueType = strings.TrimPrefix(valueType, "*")
	switch {
	case valueType == "":
		return invalid
	case valueType == "[]byte":
		valueType = "bytes"
	case strings.ContainsAny(valueType, "[]"):
		// Repeated or nested-map values are not allowed in proto3 maps.
		return invalid
	}

	return fmt.Sprintf("map<%s, %s>", toProto(keyType), toProto(valueType))
}
